安装torch |
您所在的位置:网站首页 › python 安装torch库 › 安装torch |
1. 在安装前要检查电脑的上的torch和cuda版本 检查torch版本: import torch; print(torch.__version__)检查cuda版本: import torch; print(torch.version.cuda)2.下载whl进行安装在网址torch_geometric库上找到与自己torch、cuda对应的版本: 点击进入whl下载页面,找到对应的操作系统、python版本进行下载: 其次就是对其进行安装,安装顺序为: 1.torch-scatter 2.torch-sparse 3.torch-cluster 4.torch-spline-conv 5.torch-geometric 其中1-4的步骤是利用离线的安装包在本地进行安装,命令为 pip install +本地的路径+文件名称,最后一个安装包是利用镜像源下载,命令为 pip install torch-geometric +镜像源;到此本次的安装就全部结束。 Ps: 1. 镜像源: -i https://pypi.doubanio.com/simple https://mirrors.aliyun.com/pypi/simple/ https://pypi.tuna.tsinghua.edu.cn/simple2. 在安装完毕后可以用下面的这段代码进行测试一下 import torch from torch_geometric.nn import MessagePassing from torch_geometric.utils import add_self_loops, degree class GCNConv(MessagePassing): def __init__(self, in_channels, out_channels): super(GCNConv, self).__init__(aggr='add') self.lin = torch.nn.Linear(in_channels, out_channels) def forward(self, x, edge_index): edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) x = self.lin(x) return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x) def message(self, x_j, edge_index, size): row, col = edge_index deg = degree(row, size[0], dtype=x_j.dtype) deg_inv_sqrt = deg.pow(-0.5) norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] return norm.view(-1, 1) * x_j def update(self, aggr_out): return aggr_out if __name__ == '__main__': # 假设图节点属性向量的维度为16,图卷积出来的节点特征表示向量维度为32 conv = GCNConv(16, 32) x = torch.randn(5, 16) print(x.shape) edge_index = [ [0, 1, 1, 2, 1, 3], [1, 0, 2, 1, 3, 1] ] edge_index = torch.tensor(edge_index, dtype=torch.long) output = conv(x, edge_index) print(output.shape) print(output.data)3.torch_geometric库官方参考文档网站从网站上可以查看所有torch_geometric库中函数的定义说明。 |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |